#!/bin/env python

# Copyright 2005,2008,2012  Sony Corporation

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions, and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.

# THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.

import signal, sys, os, shutil, re, bisect, glob
from itertools import *
from optparse import OptionParser
import subprocess

# Version string handed to OptionParser (shown by --version).
crsym_version = 'v1.0'

######################################################################
# constants
# Regex source for an address: exactly eight hex digits with an optional
# "0x" prefix.  The group is non-capturing so the text can be embedded
# in larger patterns without shifting their group numbers.
pat_hex32 = r'(?:0x)?[0-9a-fA-F]{8}'
# Compiled form of pat_hex32, used to find addresses inside a line.
cpat_hex32 = re.compile(pat_hex32)

######################################################################
# globals
# Parsed command-line options; assigned by main().
options = None
# Architecture name and readelf path guessed from the crash report;
# both are assigned by do_guess_arch().
guess_arch = None
guess_readelf = None

######################################################################
# main
def main():
    global options
    (parser, options, args) = do_option()
    infile = None
    if args:
        infile_path = args.pop(-1)
        if infile_path == '-':
            infile = sys.stdin
        else:
            infile = open(infile_path)
    name_list = gen_name_list(args)
    mem_map = MemoryMap()
    if infile:
        in_lines = read_crash_report(infile, name_list, mem_map)
    if options.print_guess_arch:
        print_guess_arch()
        sys.exit(0)
    if options.print_guess_readelf:
        print get_readelf()
        sys.exit(0)
    if options.print_name_map:
        print_name_map(name_list)
    if options.print_mem_map:
        print str(mem_map)
    if options.print_name_map or options.print_mem_map:
        sys.exit(0)
    if not infile:
        die("no input file")
    name_map = name_list_to_name_map(name_list)
    conv_lines(in_lines, True, mem_map, name_map)
    return

######################################################################
# options
def do_option(in_args=None):
    """Parse the command line.

    in_args -- list of argument strings; defaults to sys.argv[1:] when
               None is given.

    Returns the tuple (parser, options, args) where `options` holds the
    parsed option values and `args` the remaining positional arguments.
    """
    # Identity test: only a missing argument list selects sys.argv.
    if in_args is None:
        in_args = sys.argv[1:]
    usage = '''usage: %prog [options] [ELFfile1...] <crash_report>'''
    parser = OptionParser(usage=usage, version=crsym_version)
    parser.set_defaults(em=True)
    parser.add_option("--search-dir", dest="search_dir",
                      action="append", metavar="DIR",
                      help="search DIR for elf files")
    parser.add_option("--name-map", dest="name_map",
                      metavar="FILE",
                      help="use specified name->elf-file table")
    parser.add_option("--readelf", dest="readelf",
                      help="use specified readelf "
                      "(default: find ARCH*readelf in PATH, "
                      "where ARCH is guessed from input)")
    parser.add_option("--print-name-map", dest="print_name_map",
                      action="store_true",
                      help="print name->elf-file table")
    parser.add_option("--print-mem-map", dest="print_mem_map",
                      action="store_true",
                      help="print table used to map addresses to symbols")
    parser.add_option("--print-guess-arch", dest="print_guess_arch",
                      action="store_true",
                      help="print result of guessing architecture")
    parser.add_option("--print-guess-readelf", dest="print_guess_readelf",
                      action="store_true",
                      help="print result of guessing readelf path")
    parser.add_option("--verbose", dest="verbose",
                      action="store_true",
                      help="enable debug messages")
    # The --em / --no-em switches are disabled; `em` keeps its default.
    #parser.add_option("--em", dest="em",
    #                          action="store_true",
    #                          help="input is em's output (default)")
    #parser.add_option("--no-em", dest="em",
    #                             action="store_false",
    #                             help="input is not em's output.  need --map")
    (options, args) = parser.parse_args(in_args)
    return (parser, options, args)

######################################################################
# classes
class Symbol:
    """A symbol name together with its start address.

    Instances order by start address only, which lets a sorted list of
    them be searched with bisect.
    """
    def __init__(self, start, name):
        self.name = name
        self.start = normalize_num(start)

    def __repr__(self):
        """
        >>> repr(Symbol(123, 'abc'))
        "Symbol(123L, 'abc')"
        """
        fields = (self.start, self.name)
        return "Symbol%r" % (fields,)

    def __cmp__(self, other):
        """
        >>> Symbol(1, '') <  Symbol(2, '')
        True
        """
        mine, theirs = self.start, other.start
        return cmp(mine, theirs)

class Range:
    """One address range of the memory map.

    start, end, offset -- numbers (or hex strings) stored via
                          normalize_num
    name               -- name associated with the range
    syms               -- symbol list for the range, stored as given

    Instances order by start address only.
    """
    def __init__(self, start, end, offset, name, syms):
        self.name = name
        self.syms = syms
        self.start = normalize_num(start)
        self.end = normalize_num(end)
        self.offset = normalize_num(offset)

    def __cmp__(self, other):
        """
        >>> Range(1, 2, 0, '', []) <  Range(2, 3, 0, '', [])
        True
        """
        mine, theirs = self.start, other.start
        return cmp(mine, theirs)

class MemoryMap:
    """Address ranges kept sorted by start address, searchable by address.

    Symbols of a range are loaded on first use: a range whose `syms` is
    None gets it filled in by read_elf() the first time search() hits it.
    """
    def __init__(self):
        self.ranges = []

    def add_range(self, range):
        """Insert `range` so that self.ranges stays sorted."""
        # FIXTHIS: check for overlapping ranges could be added here
        # but probably it is not needed...
        bisect.insort_right(self.ranges, range)

    def __str__(self):
        """Return one 'start-end: offset: name' line per range."""
        rows = ["0x%08x-0x%08x: 0x%08x: %s\n" % (entry.start, entry.end,
                                                 entry.offset, entry.name)
                for entry in self.ranges]
        return ''.join(rows).rstrip()

    def search(self, addr, name_map):
        """Return (range, symbol) covering `addr`, or None if unknown.

        name_map is passed on to read_elf() when the symbols of the
        matching range have not been loaded yet.
        """
        found = self.search_range(addr)
        if not found:
            return None
        # None means "not loaded yet"; an empty list is a valid result
        # and must not trigger another load.
        if found.syms is None:
            found.syms = read_elf(found.name, found, name_map)
        sym = self.search_sym(found, addr)
        if not sym:
            return None
        return (found, sym)

    def search_range(self, addr):
        """Return the range with start <= addr < end, or None."""
        addr = normalize_num(addr)
        # Throwaway range used only as the bisect key (ordering looks at
        # the start address alone).
        target = Range(addr, 0, 0, 0, None)
        idx = bisect.bisect_right(self.ranges, target)
        if idx == 0:
            return None
        cand = self.ranges[idx - 1]
        if cand.end <= addr:
            return None
        return cand

    def search_sym(self, range, addr):
        """Return the last symbol of `range` starting at or before addr.

        Returns None when every symbol starts after addr.
        """
        idx = bisect.bisect_right(range.syms, Symbol(addr, ''))
        if idx == 0:
            return None
        return range.syms[idx - 1]

######################################################################
# subs

# --------------------------------------------------------------------
# *converting lines

# skip leftmost addr in [system map], [disassemble], [stack dump]
pat_skip = re.compile(r'^(%s-%s|%s)' % (pat_hex32, pat_hex32,
                                           pat_hex32))
def conv_line(line, skip_leftaddr, mem_map, name_map):
    print line,
    if skip_leftaddr:
        match = pat_skip.search(line)
        if match:
            line = line[match.end():]
    while True:
        match = cpat_hex32.search(line)
        if not match:
            break
        hex = match.group(0)
        line = line[match.end():]
        ans = mem_map.search(hex, name_map)
        if not ans:
            continue
        (module, sym) = ans
        addr = normalize_num(hex)
        print "   +0x%08x: %s+0x%x : %s+0x%x" % (addr,
                                                 sym.name,
                                                 addr - sym.start,
                                                 module.name,
                                                 addr - module.start)

def conv_lines(lines, skip_leftaddr, mem_map, name_map):
    """Run conv_line() on every line of `lines`, in order.

    The remaining arguments are passed to conv_line() unchanged.
    """
    for line in lines:
        conv_line(line, skip_leftaddr, mem_map, name_map)

# ----------------------------------------------------------------------
# *parsing em log

def delimit_lines(pattern, lines):
    """Yield every line, preceded by None whenever it matches pattern.

    pattern -- compiled regular expression (its search() is used)
    lines   -- iterable of strings

    >>> pat = re.compile('=')
    >>> list(delimit_lines(pat, ['a', '=', 'b']))
    ['a', None, '=', 'b']
    >>> list(delimit_lines(pat, ['=', '=']))
    [None, '=', None, '=']
    """
    for text in lines:
        starts_group = pattern.search(text) is not None
        if starts_group:
            yield None
        yield text

def group_lines_by_None(lines):
    """Yield lists of consecutive items, split at the separators.

    Any false item (None, but also an empty string) acts as a separator
    and is dropped; empty groups are never produced.

    >>> list(group_lines_by_None(['a', 'b', None, 'c', 'd']))
    [['a', 'b'], ['c', 'd']]
    >>> list(group_lines_by_None(['a', None, None, 'b']))
    [['a'], ['b']]
    >>> list(group_lines_by_None([None]))
    []
    """
    pending = []
    for item in lines:
        if item:
            pending.append(item)
        elif pending:
            # Separator reached: hand out the finished group.
            yield pending
            pending = []
    if pending:
        yield pending

def group_lines(pattern, lines):
    """Split lines into groups, starting a new group at each match.

    Every line matching `pattern` opens a new group; lines before the
    first match form a group of their own.

    >>> pat = re.compile('=')
    >>> list(group_lines(pat, ['a', '=', 'b']))
    [['a'], ['=', 'b']]
    >>> list(group_lines(pat, []))
    []
    """
    marked = delimit_lines(pattern, lines)
    return group_lines_by_None(marked)

def find_section(name, sections):
    """Return the first section whose first line matches regex `name`.

    Empty sections are skipped.  Returns None when nothing matches.
    """
    hits = (sec for sec in sections if sec and re.search(name, sec[0]))
    return next(hits, None)

# One executable ("r-xp") mapping line: start-end, permissions, offset,
# a "..:.." field, a decimal field, then the mapped name.  Raw string so
# that \s and \S reach the regex engine as written.
pat_mem_map_line = re.compile(r"^ (%s)-(%s) r-xp (%s) ..:.. [0-9]+\s+(\S+)"
                              % (pat_hex32, pat_hex32, pat_hex32))
def read_crash_memory_map(lines, mem_map):
    """Add a Range to mem_map for every matching mapping line.

    lines   -- lines of the memory-map section; non-matching lines are
               ignored
    mem_map -- object with an add_range() method

    Ranges are created without symbols (syms=None).
    """
    for line in lines:
        match = pat_mem_map_line.search(line)
        if not match:
            continue
        (start, end, offset, name) = match.groups()
        # Keep only the file name when a path was recorded.
        if "/" in name:
            name = os.path.basename(name)
        mem_map.add_range(Range(start, end, offset, name, None))

def read_one_crash_report(sections, name_list, mem_map):
    """Load the memory map of one report and guess its architecture.

    sections  -- list of sections (each a list of lines)
    name_list -- not used here
    mem_map   -- receives the ranges of the 'memory maps' section, if
                 the report has one
    """
    memory_map_lines = find_section('memory maps', sections)
    if memory_map_lines:
        read_crash_memory_map(memory_map_lines, mem_map)
    do_guess_arch(sections)

def read_crash_report(lines, name_list, mem_map):
    """Pick the last report block in `lines` and load it into mem_map.

    The input is split into blocks at lines starting with ten '='
    characters; a block counts as a report when one of its "[...]"
    sections matches 'registers'.  The last such block is processed and
    returned as a list of lines.  A warning is issued if an earlier
    block also qualifies; the script exits if none does.
    """
    # FIXTHIS - we don't have multiple "blocks" (reports) per file, but
    # this extra parsing is harmless.  Leave it here for now.
    pat_separator = re.compile('^' + '=' * 10)
    pat_header = re.compile(r'^\[(.*)\]')
    # split blocks separated by ==========
    blocks = list(group_lines(pat_separator, lines))

    # walk from the end so the first hit is the last valid block
    chosen = None
    for block in reversed(blocks):
        sections = list(group_lines(pat_header, block))
        if not find_section('registers', sections):
            continue
        if chosen:
            warn("input contains duplicate sections.  "
                 "using the last one.")
            break
        read_one_crash_report(sections, name_list, mem_map)
        chosen = block
    if not chosen:
        die("empty input")
    return chosen

# ----------------------------------------------------------------------
# *creating name map

def do_search_dir(root):
    """Return (name, path) pairs for every ELF file found under root.

    The whole tree below `root` is walked.  The script exits if `root`
    is not a directory.
    """
    if not os.path.isdir(root):
        die("not directory: " + root)
    found = []
    for parent, _subdirs, filenames in os.walk(root):
        for filename in filenames:
            candidate = os.path.join(parent, filename)
            if not is_elf(candidate):
                continue
            found.append(path_to_pair(candidate))
    return found

def gen_name_list(args):
    """Collect (name, path) pairs from every configured source.

    Sources, in order: each --search-dir directory, the ELF files given
    as positional arguments (`args`), and the --name-map file.
    """
    collected = []
    for directory in options.search_dir or ():
        collected.extend(do_search_dir(directory))
    if args:
        collected.extend(name_list_from_args(args))
    if options.name_map:
        collected.extend(read_name_map(options.name_map))
    return collected

def exclude_comment(seq):
    """Yield the non-empty lines of seq without comments or padding.

    Everything from '#' to the end of the line is removed, then leading
    and trailing whitespace is stripped.  Lines that end up empty are
    skipped.
    """
    for line in seq:
        # Remove the comment first so that blanks between the content
        # and the '#' are stripped as well.
        line = re.sub('#.*', '', line).strip()
        if line:
            yield line

def read_name_map(path):
    """Read a --name-map file and return its list of 2-tuples.

    Each line that exclude_comment() lets through must consist of
    exactly two whitespace-separated words; anything else ends the
    script with a format error.
    """
    pairs = []
    # `with` closes the file even when die() exits from inside the loop.
    with open(path) as map_file:
        for line in exclude_comment(map_file):
            fields = line.split()
            if len(fields) != 2:
                die("--name-map format error: " + line.strip())
            pairs.append(tuple(fields))
    return pairs

def print_name_map(name_list):
    for pair in name_list:
        print "%s %s" % pair

def name_list_from_args(args):
    """Return (name, path) pairs for the paths in args that are ELF files.

    >>> name_list_from_args(['/bin/ls', '/bin/sh'])
    [('ls', '/bin/ls'), ('sh', '/bin/sh')]
    """
    pairs = []
    for path in args:
        if is_elf(path):
            pairs.append(path_to_pair(path))
    return pairs

def path_to_pair(path):
    """Return (name, path); name is the file name minus a '.ko' suffix.

    >>> path_to_pair('/dir/file')
    ('file', '/dir/file')
    >>> path_to_pair('/dir/file.ko')
    ('file', '/dir/file.ko')
    """
    filename = os.path.basename(path)
    return (re.sub(r'\.ko$', '', filename), path)

def name_list_to_name_map(ls):
    """Turn a list of (name, path) pairs into a name -> path dict.

    When a name occurs more than once the last path wins; the clash is
    reported through vwarn().
    """
    table = {}
    for name, path in ls:
        try:
            earlier = table[name]
        except KeyError:
            pass
        else:
            vwarn("%s multiply found (%s, %s)" %(name, earlier, path))
        table[name] = path
    return table

# ----------------------------------------------------------------------
# *command helpers
def cmd_output_lines(arg_list):
    """Run the command arg_list and return its stdout as a list of lines.

    Line endings are kept.  The exit status of the command is not
    examined.
    """
    # FIXTHIS - subprocess.check_output is not available until Python 2.7
    # I'm using 2.6, so implement my own
    child = subprocess.Popen(arg_list, bufsize=-1, stdout=subprocess.PIPE)
    captured, _unused = child.communicate()
    return captured.splitlines(True)

# ----------------------------------------------------------------------
# *reading elf
def is_elf(path):
    """Return True if path is a readable regular file starting with the
    ELF magic number, False otherwise.

    >>> is_elf('/bin/ls')
    True
    >>> is_elf('/bin/sh') # symlink is dereferenced
    True
    >>> is_elf('/etc/rc.local')
    False
    """
    if not os.path.isfile(path):
        return False
    if not os.access(path, os.R_OK):
        return False
    # Binary mode: the magic number is raw bytes, and no newline or
    # character-set translation may be applied to them.
    with open(path, 'rb') as elf_file:
        magic = elf_file.read(4)
    return magic == b'\x7fELF'

def read_elf(name, range, name_map):
    """Return the sorted list of function Symbols of module `name`.

    name     -- module name, looked up in name_map
    range    -- Range of the module; for files that are not of type
                EXEC its start address is added to every symbol value
    name_map -- dict mapping module names to ELF file paths

    The symbols come from the output of `readelf -Ws`.  Only FUNC
    symbols that belong to a numbered section are used, and sections
    listed by exclude_sections() are left out.  An empty list is
    returned (with a verbose warning) when the module is unknown or
    its file is not an ELF file.
    """
    if name not in name_map:
        vwarn("unknown module: %s" % name)
        return []
    elf_file = name_map[name]
    if not is_elf(elf_file):
        vwarn("not elf: %s" % elf_file)
        return []
    execp = is_exec(elf_file)
    section_to_num = read_elf_sections(elf_file)
    ex_sections = exclude_sections(section_to_num)
    symtab_lines = cmd_output_lines([get_readelf(), '-Ws', elf_file])
    syms = []
    for line in symtab_lines:
        # Expected columns: Num Value Size Type Bind Vis Ndx Name
        words = line.split()
        if len(words) < 8:
            continue
        addr = words[1]
        sym_type = words[3]
        section_idx = words[6]
        sym_name = words[7]
        if sym_type != 'FUNC': # and sym_type != 'OBJECT':
            continue
        # Skip symbols without a numeric section index: 'UND', and
        # also special values such as 'ABS', which int() cannot parse.
        if not section_idx.isdigit():
            continue
        if int(section_idx) in ex_sections:
            continue
        if not execp:
            addr = normalize_num(addr) + range.start
        syms.append(Symbol(addr, sym_name))
    syms.sort()
    return syms

def exclude_sections(section_to_num):
    """Return the numbers of the sections whose symbols are ignored.

    Adhoc way to exclude the .ko init/exit sections, which may overlap
    the main .text section.  section_to_num maps section names to
    section numbers; names that are absent are simply skipped.
    """
    unwanted = ('.init.text', '.exit.text')
    return [section_to_num[sec] for sec in unwanted if sec in section_to_num]

# Section header row of `readelf -WS`: "[ N] name ..."
pat_elf_section_header_lines = re.compile(r'^\s+\[\s*(\d+)\]\s+(\S+)')
def read_elf_sections(elf_file):
    """Return a dict mapping section names to section numbers.

    The table is parsed from the output of `readelf -WS elf_file`.
    """
    name_to_num = {}
    for row in cmd_output_lines([get_readelf(), '-WS', elf_file]):
        hit = pat_elf_section_header_lines.search(row)
        if not hit:
            continue
        num_text, section_name = hit.groups()
        name_to_num[section_name] = int(num_text)
    return name_to_num

def is_exec(elf_file):
    """Return True if the ELF header of elf_file reports type EXEC."""
    header = read_elf_header(elf_file)
    # Raw string so that \s reaches the regex engine as written.
    return bool(re.search(r"Type:\s+EXEC", header))

def read_elf_header(elf_file):
    """Return the output of `readelf -h elf_file` as a single string."""
    return ''.join(cmd_output_lines([get_readelf(), '-h', elf_file]))

# ----------------------------------------------------------------------
# guess readelf

def print_guess_arch():
    if guess_arch:
        print guess_arch
    else:
        print "unknown arch"

def get_readelf():
    """Return the readelf command to run.

    The --readelf option wins when given; otherwise the guessed readelf
    is used.  The script exits when neither is available.
    """
    chosen = options.readelf
    if chosen:
        return chosen
    if not guess_arch:
        die("--readelf not specified and guessing arch failed")
    if not guess_readelf:
        die("--readelf not specified and guessing readelf failed")
    return guess_readelf

def do_guess_readelf(arch):
    """Return the path of an '<arch>*readelf' program on PATH, or None.

    The PATH directories are searched in order for file names that
    start with `arch` and end in 'readelf'.  The first directory with a
    match decides; within it the alphabetically first match is returned
    so that the choice is repeatable.  None is returned when nothing is
    found or when PATH is not set.
    """
    search_path = os.environ.get('PATH')
    if search_path is None:
        return None
    for directory in search_path.split(os.pathsep):
        matches = glob.glob(os.path.join(directory, arch + '*readelf'))
        if matches:
            return sorted(matches)[0]
    return None

def do_guess_arch(sections):
    """Guess the architecture from the 'registers' section.

    Sets the globals guess_arch ('arm', 'powerpc', 'mips', or '' when
    no pattern matches) and guess_readelf (result of do_guess_readelf
    for that name).  Nothing is changed when the report has no
    'registers' section.
    """
    global guess_arch, guess_readelf
    sec_regs = find_section('registers', sections)
    if not sec_regs:
        return

    def line_at(index):
        # A section that is too short just fails to match instead of
        # raising IndexError.
        if index < len(sec_regs):
            return sec_regs[index]
        return ''

    guess_arch = ""
    # FIXTHIS - need to verify PPC and MIPS regex in do_guess_arch
    # FIXTHIS - need to add x86 regex in do_guess_arch
    if re.search(r'^\s*r0 ', line_at(1)):
        guess_arch = 'arm'
    elif re.search(r'^ 00:', line_at(2)):
        guess_arch = 'powerpc'
    elif re.search(r'^ 0: r0:', line_at(1)):
        guess_arch = 'mips'
    guess_readelf = do_guess_readelf(guess_arch)

# ----------------------------------------------------------------------
# *misc

def normalize_num(val):
    """Convert val to a long as used by MemoryMap.

    Strings are always read as hexadecimal, with or without a "0x"
    prefix; any other value is passed to long() directly.

    >>> normalize_num(123)
    123L
    >>> normalize_num(123L)
    123L
    >>> normalize_num("123")
    291L
    >>> normalize_num("0x123")
    291L
    """
    if not isinstance(val, basestring):
        return long(val)
    return long(val, 16)

def warn(str):
    """Write a warning, prefixed with the program name, to stderr."""
    message = sys.argv[0] + ": warning: " + str
    print >>sys.stderr, message

def verbose(str):
    """Write a message to stderr, but only when --verbose was given."""
    if not options.verbose:
        return
    message = sys.argv[0] + ": " + str
    print >>sys.stderr, message

def vwarn(str):
    """Issue a warning through warn(), but only when --verbose was given."""
    if not options.verbose:
        return
    warn(str)

def die(str):
    """Exit the script with a message prefixed by the program name.

    Raises SystemExit carrying the message.
    """
    message = sys.argv[0] + ': ' + str
    raise SystemExit(message)

######################################################################
# top level

if __name__ == "__main__" and "doctest" not in locals():
    # Give SIGPIPE its default action before running.
    # NOTE(review): presumably so that piping the output into a
    # program that exits early ends the script quietly - confirm.
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)
    main()
